--- Instantiation of functions at their required types
module frege.compiler.gen.java.Instantiation where

import frege.Prelude hiding(<+>)

import Lib.PP(text, <+>, <>, <+>, <+/>)
import Data.TreeMap(member)

import Compiler.enums.Flags(TRACEG)

import Compiler.classes.Nice(nice, nicer)

import Compiler.instances.Nicer(nicectx, nicerctx)

import Compiler.types.Positions(Position)
import Compiler.types.Global(Global, Symbol, StG, SymInfo8, getST)
import Compiler.types.Types(Ctx, Context, TauT, SigmaT, RhoT, Tau, Sigma, Rho, pSigma)
-- import Compiler.types.Expression(Expr)
import Compiler.types.QNames
import Compiler.types.Symbols
import Compiler.types.JNames(JName)
import Compiler.types.AbstractJava

import Compiler.common.Errors as E()
import Compiler.common.Types as CT
import Compiler.common.JavaName

import Compiler.tc.Util(impliesG, reducedCtx)
import Compiler.Utilities as U()

import Compiler.gen.java.Common
import Compiler.gen.java.Bindings


--- Collects the contexts found in the types of all symbols of the
--- code generation environment of @g@.
--- The environment list is walked in reverse order, and the contexts of each
--- symbol are appended in the order they appear in the symbol's type.
envCtxs g = [ c | sym <- symbols, c <- contextsOf sym ]
    where
        -- the environment, last element first
        symbols        = reverse (Global.genEnv g)
        -- all symbols are considered; a restriction to symbols whose type
        -- has bound type variables, i.e. not (null (Symbol.typ s).bound),
        -- is intentionally not applied
        contextsOf sym = (Symbol.typ sym).rho.context

{--
    Returns those contexts of the code generation environment (see 'envCtxs')
    that imply, according to 'impliesG', one of the contexts in @ctxs@.

    The result is in environment order. An environment context appears once
    for every element of @ctxs@ it implies, so it may be listed repeatedly.
    -}
resolvableCtxs ∷ Global → [Context] → [Context]
resolvableCtxs g ctxs = [ have | have <- envCtxs g, _ <- impliedBy have ]
    where
        -- the wanted contexts that follow from the environment context @have@
        impliedBy have = filter (impliesG g have) ctxs


{--
    Build the Java expression ('JExpr') that stands for the type class
    constraint @ctx@.

    If the head of the flattened constraint type is a type constructor,
    the expression is constructed from the matching instance of the class
    (local @makeCtx@); otherwise it is looked up (local @findCtx@).

    Problems are reported through @E.error@ and @E.fatal@ at position @pos@.
    -}
resolveConstraint ∷ Position → Context → StG JExpr
resolveConstraint pos (ctx@Ctx {cname, tau}) = do
         g <- getST
         E.logmsg TRACEG pos (text "resolveConstraint: " <+> text (nicerctx [ctx] g))
         -- a type constructor at the head means an instance can be named directly
         if make then makeCtx else findCtx
     where
         tauflat = tau.flat
         -- head of the flattened type
         -- NOTE(review): relies on tau.flat never being empty - confirm
         tcon  = head tauflat 
         -- true if and only if the head of the flattened type is a 'TCon'
         make | TCon {name} <- tcon = true
              | otherwise = false
 
         -- Used when the head of the type is not a type constructor.
         -- Case 1: the type is a flexible meta type variable; this is
         -- reported as fatal error.
         -- NOTE(review): the JAtom "null" result presumably only matters
         -- if E.fatal returns - confirm
         findCtx | Meta tv <- tau, tv.isFlexi = do
                         g <- getST
                         E.fatal pos (text ("unknown context: " ++ nice cname g ++ " " ++ nice tau g))
                         pure (JAtom "null")
                 -- Case 2: a type application (a b) and a class for which
                 -- isSpecialClassName holds: the constraint is resolved for
                 -- a alone and the result is passed to the static member
                 -- PreludeList.kindedCtx, whose type arguments are the
                 -- boxed Java types of a and b.
                 | TApp a b ← tau,
                    isSpecialClassName cname = do
                        g   ←  getST
                        it  ←  resolveConstraint pos ctx.{tau = a}
                        let mem     =  JX.staticMember (JName "PreludeList" "kindedCtx")
                            gmem    =  mem.{targs=map boxed [tauJT g a, tauJT g b]}
                        pure (JInvoke gmem [it])
                 -- Case 3: search the contexts of the environment for one
                 -- that implies the wanted context.
                 | otherwise = do
             g <- getST
             let
                 -- each environment context paired with a name from allCtxNames
                 -- NOTE(review): presumably these are the names the contexts
                 -- have in the generated code - confirm against allCtxNames
                 ctxsnms = zip (envCtxs g) allCtxNames
                 implies = impliesG g
             E.logmsg TRACEG pos (text ("findCtx: looking for  " ++ nice cname g ++ " " ++ nice tau g))
             E.logmsg TRACEG pos (text ("findCtx: we have  " ++ nicectx (envCtxs g) g))
             -- names of all environment contexts that imply ctx
             let ok = [ name | (ctx1, name) <- ctxsnms, ctx1 `implies` ctx]
             E.logmsg TRACEG pos (text ("findCtx: ok= " ++ show ok))
             if (null ok)
                 then do 
                    -- nothing found: report errors and yield a placeholder atom
                    E.error pos (text ("FATAL: Cant find context for " 
                            ++ nice cname g ++ " " ++ nice tau g))
                    E.error pos (text ("This is a compiler error. Sorry."))
                    pure (JAtom "UNKNOWN_CONTEXT")
                 -- otherwise the first matching name is used
                 else pure ((JAtom • head) ok)
         -- Used when the head of the type is a type constructor: finds the
         -- instance of class cname for that constructor and builds an
         -- expression from it.
         makeCtx  = do
             csym <- U.findC cname
             let special = isSpecialClassName cname
             case tcon of
                 -- select the instances of the class whose first component
                 -- equals the type constructor name; the first hit is used
                 TCon {name} -> case filter ((name ==) • fst) csym.insts of
                     (_,iname):_ -> do
                         inst <- U.findI iname
                         g <- getST
                         -- the wanted type as sigma type without bound
                         -- variables and without contexts
                         let crho = RhoTau [] tau
                             csig = ForAll [] crho
                         E.logmsg TRACEG pos (text ("makeCtx: unify " ++ nice inst.typ g ++ "  with  "
                                 ++ nice csig g))
                         let taujt = tauJT g tau
                         -- array classes cannot be used when the Java type
                         -- of tau is primitive
                         when (isArrayClassName cname && (isJust . isPrimitive) taujt) do
                            E.error pos (
                                    text "Implementation restriction: generic use of "
                                    <+> text (show taujt) <> text "[]"
                                    <+> text " is unfortunately impossible." 
                                )
                         -- tree: result of unifying the instance type with
                         --       the wanted type
                         -- rho:  the instance's rho type after substitution
                         -- gargs: boxed Java types of the substituted
                         --       instance type variables, restricted to
                         --       those variables that are members of tree
                         let tree = unifySigma g inst.typ csig
                             rho  = substRho tree inst.typ.rho
                             gargs = map (boxed . tauJT g . substTau tree) 
                                        (filter ((`member` tree) . _.var) 
                                            inst.typ.tvars)
                         -- rhojt <- rhoJT rho
                         E.logmsg TRACEG pos (text ("makeCtx substituted: " ++ nice rho g))
                         -- let subctx = map (TC.reducedCtx g) rho.context
                         -- the contexts of the substituted instance type are
                         -- resolved recursively
                         args <- mapM (resolveConstraint pos) rho.context
                         let jiname = symJavaName g inst
                         -- jit: the Java name of the instance with type arguments gargs
                         -- jex: the resulting expression, one of
                         --   * special class, no sub-contexts: invocation of
                         --     static "mk" with type argument boxed taujt
                         --   * no sub-contexts, no type arguments: static "it"
                         --   * no sub-contexts: invocation of static "mk"
                         --   * otherwise: a new instance object that gets
                         --     the resolved sub-contexts as arguments
                         let jit    = Constr jiname gargs              -- jitjts
                             jex
                                | special, null args    = JInvoke (JX.static "mk" jit).{targs=[boxed taujt]} []
                                | null args, null gargs = JX.static "it" jit 
                                | null args             = JInvoke (JX.static "mk" jit) [] 
                                | otherwise             = JNew jit args
                                --where 
                                --    mem     =  JX.staticMember (JName "PreludeList" "kindedCtx")
                         E.logmsg TRACEG pos (text ("makeCtx: " ++ showJex jex))
                         pure jex
                     -- the class has no instance for this type constructor
                     [] -> do
                         g <- getST
                         E.fatal pos (text ("makeCtx: instance " ++ nice cname g ++ " " ++ nice tau g ++ " not found."))
                 -- makeCtx is only selected when the head is a 'TCon',
                 -- hence anything else here is reported as fatal error
                 other -> do
                     g <- getST
                     E.fatal pos (text ("makeCtx: head is " ++ nice other g))



{--
    Instantiate a pattern bound symbol at a given type.

    This matters only for polymorphic functions whose @forall@ type
    carries constraints, for example

    > f :: forall e.Num e => [e] -> [e] @@ Num a => [a] -> [a]

    where the 'Num' constraint must be applied to @f@
    to get a @Func<Lazy<List<A>>, List<A>>@

    The binding is made strict and its expression is converted to the
    lambda type of @sigma@. The contexts of @sigma@ are then split in two groups:
    contexts on a type variable that is an element of @sigma.vars@ remain
    in the type of the result, all others are resolved with 'resolveConstraint',
    wrapped through the static @lazy@ member of their thunk type and
    passed to @apply@ on the converted expression.

    The intended Java expression is:
        (Func<Lazy<List<A>>, List<A>>)(Lazy<List<A>> arg$n ->
            ((Func<Num<A>, Lazy<List<A>>, List<A>>)f).apply(ctx1, arg$n));
    Or, in the case of CAFs:
        ((Func<Num<A>, List<A>>)f).apply(ctx1)

    When there is no context to resolve, the converted binding is returned
    with the lambda type as its Java type.
    -}
instPatternBound :: Position -> Binding -> Sigma -> StG Binding
instPatternBound pos bindns sigma = do
    g <- getST
    E.logmsg TRACEG pos (text ("instPatternBound: "
                             ++ show bindns 
                             ++ " @@ " ++ nice sigma g))
    let allCtxs  = sigma.rho.context
        -- contexts on type variables listed in sigma.vars are not resolved here
        isHigher Ctx{tau=TVar{var}} = var `elem` sigma.vars
        isHigher _                  = false
        kept     = filter isHigher allCtxs
        needed   = filter (not . isHigher) allCtxs
        -- make sure we don't hit a Lazy<Lambda>
        sbind    = strictBind g bindns
        -- Func<....>
        funcType = lambdaType (sigmaJT g sigma)
        target   = convertHigher sbind.jex funcType

    resolved ← mapM (resolveConstraint pos) needed
    let -- every resolved context is passed through the static "lazy" member
        -- of the thunk type that belongs to the Java type of its context
        thunked  = [ JInvoke (JX.static "lazy" (inThunk (ctxJT g ctx))) [jx]
                        | (jx, ctx) <- zip resolved needed ]
        -- nothing to apply: the converted expression at the lambda type
        plain    = (newBind g sigma target).{jtype = funcType}
        -- apply the contexts; only the unresolved ones stay in the type
        applied  = (newBind g sigma.{rho ← _.{context=kept}} target).{
                        jex ← JX.invoke thunked . JX.xmem "apply",
                        jtype ← lazy
                    }
    pure (if null resolved then plain else applied)

